{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8033b2eb",
   "metadata": {},
   "source": [
    "# Deploying Quantization Aware Trained models in INT8 using Torch-TensorRT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69ec3ca7",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "Quantization Aware Training (QAT) simulates quantization during training by quantizing weights and activations. This helps reduce the loss in accuracy when a network trained in FP32 is converted to INT8 for faster inference. QAT introduces additional quantizer nodes in the graph, which are used to determine the dynamic ranges of weights and activations. In this notebook, we walk through the following steps, from training to inference, of a QAT model with Torch-TensorRT.\n",
    "\n",
    "1. [Requirements](#1)\n",
    "2. [VGG16 Overview](#2)\n",
    "3. [Training a baseline VGG16 model](#3)\n",
    "4. [Apply Quantization](#4)\n",
    "5. [Model Calibration](#5)\n",
    "6. [Quantization Aware Training](#6)\n",
    "7. [Export to TorchScript](#7)\n",
    "8. [Inference using Torch-TensorRT](#8)\n",
    "9. [References](#9)"
   ]
  },
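  {
   "cell_type": "markdown",
   "id": "3f1a9c01",
   "metadata": {},
   "source": [
    "To make simulated (fake) quantization concrete, here is a minimal sketch in plain PyTorch that does not depend on the `pytorch-quantization` toolkit. It assumes symmetric, per-tensor INT8 quantization in the narrow range [-127, 127] with scale `amax / 127`. The helper name `fake_quantize` is hypothetical. The tensor is rounded to integers and immediately dequantized back to floating point, so downstream layers see the same rounding and clipping error that INT8 inference would introduce.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def fake_quantize(x: torch.Tensor, amax: float, num_bits: int = 8) -> torch.Tensor:\n",
    "    # Hypothetical helper: symmetric fake quantization in the narrow range [-bound, bound].\n",
    "    bound = 2 ** (num_bits - 1) - 1                         # 127 for 8 bits\n",
    "    scale = amax / bound                                    # real value of one integer step\n",
    "    q = torch.clamp(torch.round(x / scale), -bound, bound)  # quantize to integers\n",
    "    return q * scale                                        # dequantize back to float\n",
    "\n",
    "x = torch.tensor([-2.0, -0.5, 0.0, 0.3, 1.0, 3.0])\n",
    "xq = fake_quantize(x, amax=1.0)\n",
    "print(xq)  # values outside [-amax, amax] are clipped to -1.0 and 1.0\n",
    "```\n",
    "\n",
    "Values inside `[-amax, amax]` incur at most half a quantization step of error, while values outside that range are clipped. Choosing `amax` well, which is the goal of calibration, balances these two sources of error."
   ]
  },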
  {
   "cell_type": "markdown",
   "id": "79655ea8",
   "metadata": {},
   "source": [
    "<a id=\"1\"></a>\n",
    "##  1. Requirements\n",
    "Install the <a href=\"https://github.com/NVIDIA/Torch-TensorRT/tree/master/examples/int8/training/vgg16#prequisites\">required dependencies</a> and import the libraries below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6493e915",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as data\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch_tensorrt\n",
    "\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import pytorch_quantization\n",
    "from pytorch_quantization import nn as quant_nn\n",
    "from pytorch_quantization import quant_modules\n",
    "from pytorch_quantization.tensor_quant import QuantDescriptor\n",
    "from pytorch_quantization import calib\n",
    "from tqdm import tqdm\n",
    "\n",
    "print(pytorch_quantization.__version__)\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, \"../examples/int8/training/vgg16\")\n",
    "from vgg16 import vgg16\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4de5060a",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "##  2. VGG16 Overview\n",
    "### Very Deep Convolutional Networks for Large-Scale Image Recognition\n",
    "VGG is one of the earliest families of image classification networks to use small (3x3) convolution filters throughout, and it achieved significant improvements on the ImageNet recognition challenge. The network architecture is shown below:\n",
    "<img src=\"https://neurohive.io/wp-content/uploads/2018/11/vgg16-1-e1542731207177.png\">\n",
    " "
   ]
  },
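  {
   "cell_type": "markdown",
   "id": "3f1a9c02",
   "metadata": {},
   "source": [
    "A little arithmetic shows why small filters help. The sketch below assumes stride-1 convolutions with `C` input and output channels and ignores biases. This is an illustrative simplification, not the exact VGG16 layer list. Under these assumptions, two stacked 3x3 convolutions see the same 5x5 receptive field as one 5x5 convolution, and three stacked 3x3 convolutions see 7x7. In both cases the stack uses fewer weights and adds extra nonlinearities in between.\n",
    "\n",
    "```python\n",
    "def receptive_field(num_layers: int, kernel: int = 3) -> int:\n",
    "    # Each stride-1 kxk convolution grows the receptive field by (kernel - 1).\n",
    "    return 1 + num_layers * (kernel - 1)\n",
    "\n",
    "def conv_weights(kernel: int, channels: int, num_layers: int = 1) -> int:\n",
    "    # Weights of num_layers kxk convolutions mapping C -> C channels (biases ignored).\n",
    "    return num_layers * kernel * kernel * channels * channels\n",
    "\n",
    "C = 256  # illustrative channel count\n",
    "for n in (2, 3):\n",
    "    rf = receptive_field(n)\n",
    "    stacked = conv_weights(3, C, n)\n",
    "    single = conv_weights(rf, C)\n",
    "    print(f'{n} stacked 3x3 convs: {rf}x{rf} receptive field, {stacked} weights vs {single} for one {rf}x{rf} conv')\n",
    "```"
   ]
  },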
  {
   "cell_type": "markdown",
   "id": "7a5afc49",
   "metadata": {},
   "source": [
    "<a id=\"3\"></a>\n",
    "##  3. Training a baseline VGG16 model\n",
    "We train VGG16 on the CIFAR-10 dataset. The cell below defines the training and testing datasets and dataloaders, and downloads the CIFAR-10 data into the `data` directory if it is not already present. Data preprocessing is performed with `torchvision` transforms."
   ]
  },
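  {
   "cell_type": "markdown",
   "id": "3f1a9c03",
   "metadata": {},
   "source": [
    "`transforms.ToTensor()` scales pixels to [0, 1]. `transforms.Normalize(mean, std)` then standardizes each channel `c` as `(x - mean[c]) / std[c]`. The sketch below reproduces that second step in plain PyTorch, assuming a single `(C, H, W)` image tensor. The helper `normalize` is hypothetical.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Per-channel statistics passed to transforms.Normalize in the next cell.\n",
    "mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)\n",
    "std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)\n",
    "\n",
    "def normalize(img: torch.Tensor) -> torch.Tensor:\n",
    "    # (x - mean[c]) / std[c], broadcasting mean and std over height and width.\n",
    "    return (img - mean) / std\n",
    "\n",
    "img = torch.full((3, 32, 32), 0.5)  # a uniform gray 32x32 image, already in [0, 1]\n",
    "out = normalize(img)\n",
    "print(out[:, 0, 0])\n",
    "```"
   ]
  },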
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d2c4c45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "\n",
    "# ========== Define Training dataset and dataloaders =============#\n",
    "training_dataset = datasets.CIFAR10(root='./data',\n",
    "                                        train=True,\n",
    "                                        download=True,\n",
    "                                        transform=transforms.Compose([\n",
    "                                            transforms.RandomCrop(32, padding=4),\n",
    "                                            transforms.RandomHorizontalFlip(),\n",
    "                                            transforms.ToTensor(),\n",
    "                                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "                                        ]))\n",
    "\n",
    "training_dataloader = torch.utils.data.DataLoader(training_dataset,\n",
    "                                                      batch_size=32,\n",
    "                                                      shuffle=True,\n",
    "                                                      num_workers=2)\n",
    "\n",
    "# ========== Define Testing dataset and dataloaders =============#\n",
    "testing_dataset = datasets.CIFAR10(root='./data',\n",
    "                                   train=False,\n",
    "                                   download=True,\n",
    "                                   transform=transforms.Compose([\n",
    "                                       transforms.ToTensor(),\n",
    "                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "                                   ]))\n",
    "\n",
    "testing_dataloader = torch.utils.data.DataLoader(testing_dataset,\n",
    "                                                 batch_size=16,\n",
    "                                                 shuffle=False,\n",
    "                                                 num_workers=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2bd092b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, dataloader, crit, opt, epoch):\n",
    "#     global writer\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for batch, (data, labels) in enumerate(dataloader):\n",
    "        data, labels = data.cuda(), labels.cuda(non_blocking=True)\n",
    "        opt.zero_grad()\n",
    "        out = model(data)\n",
    "        loss = crit(out, labels)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        if batch % 500 == 499:\n",
    "            print(\"Batch: [%5d | %5d] loss: %.3f\" % (batch + 1, len(dataloader), running_loss / 500))\n",
    "            running_loss = 0.0\n",
    "        \n",
    "def test(model, dataloader, crit, epoch):\n",
    "    global writer\n",
    "    global classes\n",
    "    total = 0\n",
    "    correct = 0\n",
    "    loss = 0.0\n",
    "    class_probs = []\n",
    "    class_preds = []\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for data, labels in dataloader:\n",
    "            data, labels = data.cuda(), labels.cuda(non_blocking=True)\n",
    "            out = model(data)\n",
    "            loss += crit(out, labels)\n",
    "            preds = torch.max(out, 1)[1]\n",
    "            class_probs.append([F.softmax(i, dim=0) for i in out])\n",
    "            class_preds.append(preds)\n",
    "            total += labels.size(0)\n",
    "            correct += (preds == labels).sum().item()\n",
    "\n",
    "    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])\n",
    "    test_preds = torch.cat(class_preds)\n",
    "\n",
    "    return loss / total, correct / total\n",
    "\n",
    "def save_checkpoint(state, ckpt_path=\"checkpoint.pth\"):\n",
    "    torch.save(state, ckpt_path)\n",
    "    print(\"Checkpoint saved\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c80a86cc",
   "metadata": {},
   "source": [
    "*Define the VGG model that we are going to perform QAT on.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8c564b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CIFAR 10 has 10 classes\n",
    "model = vgg16(num_classes=len(classes), init_weights=False)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abc00452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Declare Learning rate\n",
    "lr = 0.1\n",
    "state = {}\n",
    "state[\"lr\"] = lr\n",
    "\n",
    "# Use cross entropy loss for classification and SGD optimizer\n",
    "crit = nn.CrossEntropyLoss()\n",
    "opt = optim.SGD(model.parameters(), lr=state[\"lr\"], momentum=0.9, weight_decay=1e-4)\n",
    "\n",
    "\n",
    "# Adjust learning rate based on epoch number\n",
    "def adjust_lr(optimizer, epoch):\n",
    "    global state\n",
    "    new_lr = lr * (0.5**(epoch // 12)) if state[\"lr\"] > 1e-7 else state[\"lr\"]\n",
    "    if new_lr != state[\"lr\"]:\n",
    "        state[\"lr\"] = new_lr\n",
    "        print(\"Updating learning rate: {}\".format(state[\"lr\"]))\n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group[\"lr\"] = state[\"lr\"]"
   ]
  },
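  {
   "cell_type": "markdown",
   "id": "3f1a9c04",
   "metadata": {},
   "source": [
    "`adjust_lr` implements a step decay that halves the learning rate every 12 epochs, computed as `lr * 0.5 ** (epoch // 12)` with a zero-based `epoch`. The pure-Python sketch below tabulates the same formula for the 25-epoch run. The helper `step_lr` is hypothetical. The rate drops to 0.05 at the 13th epoch and to 0.025 at the 25th, which is when `adjust_lr` prints its update messages. At this run length the `1e-7` floor in `adjust_lr` is never reached.\n",
    "\n",
    "```python\n",
    "def step_lr(base_lr: float, epoch: int, step: int = 12, gamma: float = 0.5) -> float:\n",
    "    # Same formula as adjust_lr above, without touching an optimizer.\n",
    "    return base_lr * gamma ** (epoch // step)\n",
    "\n",
    "schedule = [step_lr(0.1, e) for e in range(25)]  # zero-based epochs 0..24\n",
    "for e in (0, 11, 12, 23, 24):\n",
    "    print(f'epoch {e + 1:2d}: lr = {schedule[e]:g}')\n",
    "```"
   ]
  },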
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d80865a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: [    1 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 13.288\n",
      "Batch: [ 1000 |  1563] loss: 11.345\n",
      "Batch: [ 1500 |  1563] loss: 11.008\n",
      "Test Loss: 0.13388 Test Acc: 13.23%\n",
      "Epoch: [    2 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 10.742\n",
      "Batch: [ 1000 |  1563] loss: 10.311\n",
      "Batch: [ 1500 |  1563] loss: 10.141\n",
      "Test Loss: 0.11888 Test Acc: 23.96%\n",
      "Epoch: [    3 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.877\n",
      "Batch: [ 1000 |  1563] loss: 9.821\n",
      "Batch: [ 1500 |  1563] loss: 9.818\n",
      "Test Loss: 0.11879 Test Acc: 24.68%\n",
      "Epoch: [    4 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.677\n",
      "Batch: [ 1000 |  1563] loss: 9.613\n",
      "Batch: [ 1500 |  1563] loss: 9.504\n",
      "Test Loss: 0.11499 Test Acc: 23.68%\n",
      "Epoch: [    5 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.560\n",
      "Batch: [ 1000 |  1563] loss: 9.536\n",
      "Batch: [ 1500 |  1563] loss: 9.309\n",
      "Test Loss: 0.10990 Test Acc: 27.84%\n",
      "Epoch: [    6 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.254\n",
      "Batch: [ 1000 |  1563] loss: 9.234\n",
      "Batch: [ 1500 |  1563] loss: 9.188\n",
      "Test Loss: 0.11594 Test Acc: 23.29%\n",
      "Epoch: [    7 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.141\n",
      "Batch: [ 1000 |  1563] loss: 9.110\n",
      "Batch: [ 1500 |  1563] loss: 9.013\n",
      "Test Loss: 0.10732 Test Acc: 29.24%\n",
      "Epoch: [    8 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 9.120\n",
      "Batch: [ 1000 |  1563] loss: 9.086\n",
      "Batch: [ 1500 |  1563] loss: 8.948\n",
      "Test Loss: 0.10732 Test Acc: 27.24%\n",
      "Epoch: [    9 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 8.941\n",
      "Batch: [ 1000 |  1563] loss: 8.997\n",
      "Batch: [ 1500 |  1563] loss: 9.028\n",
      "Test Loss: 0.11299 Test Acc: 25.52%\n",
      "Epoch: [   10 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 8.927\n",
      "Batch: [ 1000 |  1563] loss: 8.837\n",
      "Batch: [ 1500 |  1563] loss: 8.860\n",
      "Test Loss: 0.10130 Test Acc: 34.61%\n",
      "Epoch: [   11 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 8.953\n",
      "Batch: [ 1000 |  1563] loss: 8.738\n",
      "Batch: [ 1500 |  1563] loss: 8.724\n",
      "Test Loss: 0.10018 Test Acc: 32.27%\n",
      "Epoch: [   12 /    25] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 8.721\n",
      "Batch: [ 1000 |  1563] loss: 8.716\n",
      "Batch: [ 1500 |  1563] loss: 8.701\n",
      "Test Loss: 0.10070 Test Acc: 29.57%\n",
      "Updating learning rate: 0.05\n",
      "Epoch: [   13 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 7.944\n",
      "Batch: [ 1000 |  1563] loss: 7.649\n",
      "Batch: [ 1500 |  1563] loss: 7.511\n",
      "Test Loss: 0.08555 Test Acc: 44.62%\n",
      "Epoch: [   14 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 7.057\n",
      "Batch: [ 1000 |  1563] loss: 6.944\n",
      "Batch: [ 1500 |  1563] loss: 6.687\n",
      "Test Loss: 0.08331 Test Acc: 52.27%\n",
      "Epoch: [   15 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 6.470\n",
      "Batch: [ 1000 |  1563] loss: 6.439\n",
      "Batch: [ 1500 |  1563] loss: 6.126\n",
      "Test Loss: 0.07266 Test Acc: 58.02%\n",
      "Epoch: [   16 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 5.834\n",
      "Batch: [ 1000 |  1563] loss: 5.801\n",
      "Batch: [ 1500 |  1563] loss: 5.622\n",
      "Test Loss: 0.06340 Test Acc: 65.17%\n",
      "Epoch: [   17 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 5.459\n",
      "Batch: [ 1000 |  1563] loss: 5.442\n",
      "Batch: [ 1500 |  1563] loss: 5.314\n",
      "Test Loss: 0.05945 Test Acc: 67.22%\n",
      "Epoch: [   18 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 5.071\n",
      "Batch: [ 1000 |  1563] loss: 5.145\n",
      "Batch: [ 1500 |  1563] loss: 5.063\n",
      "Test Loss: 0.06567 Test Acc: 64.46%\n",
      "Epoch: [   19 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 4.796\n",
      "Batch: [ 1000 |  1563] loss: 4.781\n",
      "Batch: [ 1500 |  1563] loss: 4.732\n",
      "Test Loss: 0.05374 Test Acc: 71.87%\n",
      "Epoch: [   20 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 4.568\n",
      "Batch: [ 1000 |  1563] loss: 4.564\n",
      "Batch: [ 1500 |  1563] loss: 4.484\n",
      "Test Loss: 0.05311 Test Acc: 71.12%\n",
      "Epoch: [   21 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 4.385\n",
      "Batch: [ 1000 |  1563] loss: 4.302\n",
      "Batch: [ 1500 |  1563] loss: 4.285\n",
      "Test Loss: 0.05080 Test Acc: 74.29%\n",
      "Epoch: [   22 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 4.069\n",
      "Batch: [ 1000 |  1563] loss: 4.105\n",
      "Batch: [ 1500 |  1563] loss: 4.096\n",
      "Test Loss: 0.04807 Test Acc: 75.20%\n",
      "Epoch: [   23 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 3.959\n",
      "Batch: [ 1000 |  1563] loss: 3.898\n",
      "Batch: [ 1500 |  1563] loss: 3.916\n",
      "Test Loss: 0.04743 Test Acc: 75.81%\n",
      "Epoch: [   24 /    25] LR: 0.050000\n",
      "Batch: [  500 |  1563] loss: 3.738\n",
      "Batch: [ 1000 |  1563] loss: 3.847\n",
      "Batch: [ 1500 |  1563] loss: 3.797\n",
      "Test Loss: 0.04609 Test Acc: 76.42%\n",
      "Updating learning rate: 0.025\n",
      "Epoch: [   25 /    25] LR: 0.025000\n",
      "Batch: [  500 |  1563] loss: 2.952\n",
      "Batch: [ 1000 |  1563] loss: 2.906\n",
      "Batch: [ 1500 |  1563] loss: 2.735\n",
      "Test Loss: 0.03466 Test Acc: 82.00%\n",
      "Checkpoint saved\n"
     ]
    }
   ],
   "source": [
    "# Train the model for 25 epochs to get ~80% accuracy.\n",
    "num_epochs=25\n",
    "for epoch in range(num_epochs):\n",
    "    adjust_lr(opt, epoch)\n",
    "    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state[\"lr\"]))\n",
    "\n",
    "    train(model, training_dataloader, crit, opt, epoch)\n",
    "    test_loss, test_acc = test(model, testing_dataloader, crit, epoch)\n",
    "\n",
    "    print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n",
    "    \n",
    "save_checkpoint({'epoch': epoch + 1,\n",
    "                 'model_state_dict': model.state_dict(),\n",
    "                 'acc': test_acc,\n",
    "                 'opt_state_dict': opt.state_dict(),\n",
    "                 'state': state},\n",
    "                ckpt_path=\"vgg16_base_ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1044537",
   "metadata": {},
   "source": [
    "<a id=\"4\"></a>\n",
    "##  4. Apply Quantization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c33b7f4e",
   "metadata": {},
   "source": [
    "`quant_modules.initialize()` ensures that the quantized versions of modules are used instead of the original modules. For example, when you define a model with convolution, linear, and pooling layers, `QuantConv2d`, `QuantLinear`, and the quantized pooling modules (such as `QuantMaxPool2d`) are instantiated in their place. `QuantConv2d` wraps quantizer nodes around the input and weight of a regular `Conv2d`. For more information, see the full list of <a href=\"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization/pytorch_quantization/nn/modules\">quantized modules</a> in the pytorch-quantization toolkit. The `forward` method of `QuantConv2d` in the `pytorch-quantization` toolkit looks as follows:\n",
    "\n",
    "```python\n",
    "def forward(self, input):\n",
    "    # the actual quantization happens in the next level of the class hierarchy\n",
    "    quant_input, quant_weight = self._quant(input)\n",
    "\n",
    "    if self.padding_mode == 'circular':\n",
    "        expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,\n",
    "                            (self.padding[0] + 1) // 2, self.padding[0] // 2)\n",
    "        output = F.conv2d(F.pad(quant_input, expanded_padding, mode='circular'),\n",
    "                          quant_weight, self.bias, self.stride,\n",
    "                          _pair(0), self.dilation, self.groups)\n",
    "    else:\n",
    "        output = F.conv2d(quant_input, quant_weight, self.bias, self.stride, self.padding, self.dilation,\n",
    "                          self.groups)\n",
    "\n",
    "    return output\n",
    "```"
   ]
  },
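  {
   "cell_type": "markdown",
   "id": "3f1a9c05",
   "metadata": {},
   "source": [
    "To illustrate what wrapping quantizer nodes around a `Conv2d` means, here is a simplified, hypothetical `FakeQuantConv2d` written in plain PyTorch. It is not the toolkit's implementation, and it makes several simplifying assumptions:\n",
    "\n",
    "- It assumes `padding_mode='zeros'`.\n",
    "- It derives `amax` from the current tensors instead of from calibration.\n",
    "- It quantizes the input per tensor and the weight per output channel.\n",
    "- It uses a simple straight-through estimator so that gradients can flow through the rounding during training.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def fake_quant(x: torch.Tensor, amax: torch.Tensor, bound: int = 127) -> torch.Tensor:\n",
    "    # Symmetric fake quantization; amax is a scalar or a tensor broadcastable to x.\n",
    "    scale = amax.clamp(min=1e-8) / bound\n",
    "    xq = torch.clamp(torch.round(x / scale), -bound, bound) * scale\n",
    "    # Straight-through estimator: the forward pass returns xq, gradients pass through as identity.\n",
    "    return x + (xq - x).detach()\n",
    "\n",
    "class FakeQuantConv2d(nn.Conv2d):\n",
    "    # Hypothetical, simplified stand-in for pytorch_quantization's QuantConv2d.\n",
    "    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
    "        in_amax = input.detach().abs().max()                                   # per tensor\n",
    "        w_amax = self.weight.detach().abs().amax(dim=(1, 2, 3), keepdim=True)  # per output channel\n",
    "        quant_input = fake_quant(input, in_amax)\n",
    "        quant_weight = fake_quant(self.weight, w_amax)\n",
    "        return F.conv2d(quant_input, quant_weight, self.bias, self.stride,\n",
    "                        self.padding, self.dilation, self.groups)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "conv = FakeQuantConv2d(3, 8, kernel_size=3, padding=1)\n",
    "x = torch.randn(2, 3, 32, 32)\n",
    "y = conv(x)\n",
    "ref = F.conv2d(x, conv.weight, conv.bias, padding=1)  # full-precision reference\n",
    "print(y.shape, (y - ref).abs().max().item())\n",
    "```"
   ]
  },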
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "985dc59e",
   "metadata": {},
   "outputs": [],
   "source": [
    "quant_modules.initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "164ce8cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All regular conv and FC layers are converted to their quantized counterparts because of quant_modules.initialize()\n",
    "qat_model = vgg16(num_classes=len(classes), init_weights=False)\n",
    "qat_model = qat_model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8e5f7fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# vgg16_base_ckpt is the checkpoint generated in Step 3: Training a baseline VGG16 model.\n",
    "ckpt = torch.load(\"./vgg16_base_ckpt\")\n",
    "modified_state_dict={}\n",
    "for key, val in ckpt[\"model_state_dict\"].items():\n",
    "    # Remove the 'module.' prefix added by (Distributed)DataParallel from the key names\n",
    "    if key.startswith('module.'):\n",
    "        modified_state_dict[key[7:]] = val\n",
    "    else:\n",
    "        modified_state_dict[key] = val\n",
    "\n",
    "# Load the pre-trained checkpoint\n",
    "qat_model.load_state_dict(modified_state_dict)\n",
    "opt.load_state_dict(ckpt[\"opt_state_dict\"])"
   ]
  },
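  {
   "cell_type": "markdown",
   "id": "3f1a9c06",
   "metadata": {},
   "source": [
    "Checkpoints saved from a model wrapped in `nn.DataParallel` or `DistributedDataParallel` prefix every parameter name with `module.`, which is why the cell above strips that prefix before loading. The sketch below shows a compact alternative using `str.removeprefix`, which requires Python 3.9+. It uses a toy dictionary, and the helper name `strip_module_prefix` is hypothetical. Matching the full `module.` string, dot included, avoids mangling a key that merely starts with the letters `module`.\n",
    "\n",
    "```python\n",
    "def strip_module_prefix(state_dict: dict, prefix: str = 'module.') -> dict:\n",
    "    # Remove the wrapper prefix that (Distributed)DataParallel adds to parameter names.\n",
    "    return {key.removeprefix(prefix): val for key, val in state_dict.items()}\n",
    "\n",
    "# Toy state dict with illustrative keys; real values would be tensors.\n",
    "ckpt_state = {'module.features.0.weight': 1, 'module.classifier.6.bias': 2, 'features.3.weight': 3}\n",
    "print(strip_module_prefix(ckpt_state))\n",
    "```"
   ]
  },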
  {
   "cell_type": "markdown",
   "id": "8f8a74e8",
   "metadata": {},
   "source": [
    "<a id=\"5\"></a>\n",
    "##  5. Model Calibration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2a321f9",
   "metadata": {},
   "source": [
    "The quantizer nodes inserted around the desired layers capture the dynamic range observed at each layer. Because `pytorch-quantization` uses symmetric quantization, this range is stored as a single maximum absolute value, `amax`. Calibration is the process of computing these ranges by passing calibration data, usually a subset of the training or validation data, through the model. Several calibration methods are available: `max`, plus the histogram-based `percentile`, `mse`, and `entropy` methods. We use `max` calibration because it is simple and effective."
   ]
  },
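  {
   "cell_type": "markdown",
   "id": "3f1a9c07",
   "metadata": {},
   "source": [
    "Max calibration keeps the largest absolute value observed across all calibration batches. The sketch below mimics that idea in plain PyTorch. It assumes one per-tensor `amax` for activations and one `amax` per output channel for convolution weights. These are the two granularities reported as `per-tensor` and `axis=0` in the calibration output further down. `MaxTracker` is a hypothetical class, not part of `pytorch-quantization`. The per-channel `amax` for a `(64, 3, 3, 3)` weight has shape `(64, 1, 1, 1)`, matching the shape in the `Load calibrated amax` log messages.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "class MaxTracker:\n",
    "    # Hypothetical running-max calibrator: amax = max |x| over every batch collected.\n",
    "    def __init__(self, axis=None):\n",
    "        self.axis = axis  # None -> one per-tensor amax; int -> one amax per index along that axis\n",
    "        self.amax = None\n",
    "\n",
    "    def collect(self, x: torch.Tensor) -> None:\n",
    "        x = x.detach().abs()\n",
    "        if self.axis is None:\n",
    "            batch_amax = x.max()\n",
    "        else:\n",
    "            reduce_dims = tuple(d for d in range(x.dim()) if d != self.axis)\n",
    "            batch_amax = x.amax(dim=reduce_dims, keepdim=True)\n",
    "        self.amax = batch_amax if self.amax is None else torch.maximum(self.amax, batch_amax)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "act_calib = MaxTracker()            # activations: per tensor\n",
    "for _ in range(4):                  # e.g. four calibration batches\n",
    "    act_calib.collect(torch.randn(16, 64, 8, 8))\n",
    "\n",
    "weight = torch.randn(64, 3, 3, 3)   # conv weight: (out_channels, in_channels, kH, kW)\n",
    "w_calib = MaxTracker(axis=0)        # weights: per output channel\n",
    "w_calib.collect(weight)\n",
    "print(act_calib.amax.item(), tuple(w_calib.amax.shape))\n",
    "```"
   ]
  },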
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "039423dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_amax(model, **kwargs):\n",
    "    # Load calib result\n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, quant_nn.TensorQuantizer):\n",
    "            if module._calibrator is not None:\n",
    "                if isinstance(module._calibrator, calib.MaxCalibrator):\n",
    "                    module.load_calib_amax()\n",
    "                else:\n",
    "                    module.load_calib_amax(**kwargs)\n",
    "            print(F\"{name:40}: {module}\")\n",
    "    model.cuda()\n",
    "\n",
    "def collect_stats(model, data_loader, num_batches):\n",
    "    \"\"\"Feed data to the network and collect statistics\"\"\"\n",
    "    # Enable calibrators\n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, quant_nn.TensorQuantizer):\n",
    "            if module._calibrator is not None:\n",
    "                module.disable_quant()\n",
    "                module.enable_calib()\n",
    "            else:\n",
    "                module.disable()\n",
    "\n",
    "    # Feed data to the network for collecting stats\n",
    "    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):\n",
    "        # Stop before running an extra batch so that exactly num_batches are used\n",
    "        if i >= num_batches:\n",
    "            break\n",
    "        model(image.cuda())\n",
    "\n",
    "    # Disable calibrators\n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, quant_nn.TensorQuantizer):\n",
    "            if module._calibrator is not None:\n",
    "                module.enable_quant()\n",
    "                module.disable_calib()\n",
    "            else:\n",
    "                module.enable()\n",
    "\n",
    "def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir):\n",
    "    \"\"\"\n",
    "        Feed data to the network and calibrate.\n",
    "        Arguments:\n",
    "            model: classification model\n",
    "            model_name: name to use when creating state files\n",
    "            data_loader: calibration data set\n",
    "            num_calib_batch: number of calibration batches to run\n",
    "            calibrator: type of calibration to use (max/histogram)\n",
    "            hist_percentile: percentiles to be used for histogram calibration\n",
    "            out_dir: dir to save state files in\n",
    "    \"\"\"\n",
    "\n",
    "    if num_calib_batch > 0:\n",
    "        print(\"Calibrating model\")\n",
    "        with torch.no_grad():\n",
    "            collect_stats(model, data_loader, num_calib_batch)\n",
    "\n",
    "        if not calibrator == \"histogram\":\n",
    "            compute_amax(model, method=\"max\")\n",
    "            calib_output = os.path.join(\n",
    "                out_dir,\n",
    "                F\"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth\")\n",
    "            torch.save(model.state_dict(), calib_output)\n",
    "        else:\n",
    "            for percentile in hist_percentile:\n",
    "                print(F\"{percentile} percentile calibration\")\n",
    "                compute_amax(model, method=\"percentile\", percentile=percentile)\n",
    "                calib_output = os.path.join(\n",
    "                    out_dir,\n",
    "                    F\"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth\")\n",
    "                torch.save(model.state_dict(), calib_output)\n",
    "\n",
    "            for method in [\"mse\", \"entropy\"]:\n",
    "                print(F\"{method} calibration\")\n",
    "                compute_amax(model, method=method)\n",
    "                calib_output = os.path.join(\n",
    "                    out_dir,\n",
    "                    F\"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth\")\n",
    "                torch.save(model.state_dict(), calib_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "78504a6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calibrating model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████| 32/32 [00:00<00:00, 96.04it/s]\n",
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1109 04:01:43.512364 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.513354 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.514046 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.514638 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.515270 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.515859 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.516441 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.517009 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.517600 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.518167 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.518752 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.519333 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.519911 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.520473 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.521038 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.521596 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.522170 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.522742 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.523360 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.523957 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.524581 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.525059 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.525366 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.525675 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.525962 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.526257 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.526566 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.526885 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.527188 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.527489 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.527792 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.528097 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.528387 139704147265344 tensor_quantizer.py:173] Disable MaxCalibrator\n",
      "W1109 04:01:43.528834 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.529163 139704147265344 tensor_quantizer.py:238] Call .cuda() if running on GPU after loading calibrated amax.\n",
      "W1109 04:01:43.532748 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).\n",
      "W1109 04:01:43.533468 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.534033 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([64, 1, 1, 1]).\n",
      "W1109 04:01:43.534684 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.535320 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).\n",
      "W1109 04:01:43.535983 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.536569 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([128, 1, 1, 1]).\n",
      "W1109 04:01:43.537248 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.537833 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n",
      "W1109 04:01:43.538480 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.539074 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n",
      "W1109 04:01:43.539724 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.540307 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([256, 1, 1, 1]).\n",
      "W1109 04:01:43.540952 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.541534 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.542075 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.542596 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.543248 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.543719 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.544424 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.544952 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.545530 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.546114 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.546713 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.547292 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([512, 1, 1, 1]).\n",
      "W1109 04:01:43.547902 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.548453 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.549015 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).\n",
      "W1109 04:01:43.549665 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.550436 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([4096, 1]).\n",
      "W1109 04:01:43.551925 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([]).\n",
      "W1109 04:01:43.553105 139704147265344 tensor_quantizer.py:237] Load calibrated amax, shape=torch.Size([10, 1]).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features.0._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=2.7537 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.0._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0263, 2.7454](64) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.3._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=27.5676 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.3._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0169, 1.8204](64) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.7._input_quantizer             : TensorQuantizer(8bit narrow fake per-tensor amax=15.2002 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.7._weight_quantizer            : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0493, 1.3207](128) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.10._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=7.7376 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.10._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0163, 0.9624](128) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.14._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=8.8351 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.14._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0622, 0.8791](256) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.17._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=12.5746 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.17._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0505, 0.5117](256) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.20._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=9.7203 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.20._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0296, 0.5335](256) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.24._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=8.9367 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.24._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0220, 0.3763](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.27._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=6.6539 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.27._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0151, 0.1777](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.30._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.7099 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.30._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0087, 0.1906](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.34._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=4.0491 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.34._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0106, 0.1971](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.37._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=2.1531 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.37._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0070, 0.2305](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.40._input_quantizer            : TensorQuantizer(8bit narrow fake per-tensor amax=3.3631 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "features.40._weight_quantizer           : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0023, 0.4726](512) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "avgpool._input_quantizer                : TensorQuantizer(8bit narrow fake per-tensor amax=5.3550 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.0._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=5.3550 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.0._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0026, 0.5320](4096) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.3._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=6.6733 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.3._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.0018, 0.5172](4096) calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.6._input_quantizer           : TensorQuantizer(8bit narrow fake per-tensor amax=9.4352 calibrator=MaxCalibrator scale=1.0 quant)\n",
      "classifier.6._weight_quantizer          : TensorQuantizer(8bit narrow fake axis=0 amax=[0.3877, 0.5620](10) calibrator=MaxCalibrator scale=1.0 quant)\n"
     ]
    }
   ],
   "source": [
    "#Calibrate the model using max calibration technique.\n",
    "with torch.no_grad():\n",
    "    calibrate_model(\n",
    "        model=qat_model,\n",
    "        model_name=\"vgg16\",\n",
    "        data_loader=training_dataloader,\n",
    "        num_calib_batch=32,\n",
    "        calibrator=\"max\",\n",
    "        hist_percentile=[99.9, 99.99, 99.999, 99.9999],\n",
    "        out_dir=\"./\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1aa0c109",
   "metadata": {},
   "source": [
    "<a id=\"6\"></a>\n",
    "##  6. Quantization Aware Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fe8ec11",
   "metadata": {},
   "source": [
    "In this phase, we finetune the model weights and leave the quantizer node values frozen. The dynamic ranges for each layer obtained from the calibration are kept constant while the weights of the model are finetuned to be close to the accuracy of original FP32 model (model without quantizer nodes) is preserved. Usually the finetuning of QAT model should be quick compared to the full training of the original model. Use QAT to fine-tune for around 10% of the original training schedule with an annealing learning-rate. Please refer to <a href=\"https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/\">Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT</a> for detailed recommendations. For this VGG model, it is enough to finetune for 1 epoch to get acceptable accuracy. \n",
    "During finetuning with QAT, the quantization is applied as a composition of `max`, `clamp`, `round` and `mul` ops. \n",
    "```\n",
    "# amax is absolute maximum value for an input\n",
    "# The upper bound for integer quantization (127 for int8)\n",
    "max_bound = torch.tensor((2.0**(num_bits - 1 + int(unsigned))) - 1.0, device=amax.device)\n",
    "scale = max_bound / amax\n",
    "outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)\n",
    "```\n",
    "<a href=\"https://github.com/NVIDIA/TensorRT/blob/8.0.1/tools/pytorch-quantization/pytorch_quantization/tensor_quant.py\">tensor_quant function</a> in `pytorch_quantization` toolkit is responsible for the above tensor quantization. Usually, per channel quantization is recommended for weights, while per tensor quantization is recommended for activations in a network.\n",
    "During inference, we use `torch.fake_quantize_per_tensor_affine` and `torch.fake_quantize_per_channel_affine` to perform quantization as this is easier to convert into corresponding TensorRT operators. Please refer to next sections for more details on how these operators are exported in torchscript and converted in Torch-TensorRT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1f28d228",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updating learning rate: 0.1\n",
      "Epoch: [    1 /     1] LR: 0.100000\n",
      "Batch: [  500 |  1563] loss: 2.635\n",
      "Batch: [ 1000 |  1563] loss: 2.655\n",
      "Batch: [ 1500 |  1563] loss: 2.646\n",
      "Test Loss: 0.03291 Test Acc: 82.98%\n",
      "Checkpoint saved\n"
     ]
    }
   ],
   "source": [
    "# Finetune the QAT model for 1 epoch\n",
    "num_epochs=1\n",
    "for epoch in range(num_epochs):\n",
    "    adjust_lr(opt, epoch)\n",
    "    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, state[\"lr\"]))\n",
    "\n",
    "    train(qat_model, training_dataloader, crit, opt, epoch)\n",
    "    test_loss, test_acc = test(qat_model, testing_dataloader, crit, epoch)\n",
    "\n",
    "    print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n",
    "    \n",
    "save_checkpoint({'epoch': epoch + 1,\n",
    "                 'model_state_dict': qat_model.state_dict(),\n",
    "                 'acc': test_acc,\n",
    "                 'opt_state_dict': opt.state_dict(),\n",
    "                 'state': state},\n",
    "                ckpt_path=\"vgg16_qat_ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a4dcaa2",
   "metadata": {},
   "source": [
    "<a id=\"7\"></a>\n",
    "##  7. Export to Torchscript\n",
    "Export the model to Torch script. Trace the model and convert it into torchscript for deployment. To learn more about Torchscript, please refer to https://pytorch.org/docs/stable/jit.html. Setting `quant_nn.TensorQuantizer.use_fb_fake_quant = True` enables the QAT model to use `torch.fake_quantize_per_tensor_affine` and `torch.fake_quantize_per_channel_affine` operators instead of `tensor_quant` function to export quantization operators. In torchscript, they are represented as `aten::fake_quantize_per_tensor_affine` and `aten::fake_quantize_per_channel_affine`. "
   ]
  },
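  {
   "cell_type": "markdown",
   "id": "5e1c7a3b",
   "metadata": {},
   "source": [
    "To see what the exported graph contains without tracing the full VGG model, the sketch below traces a small hypothetical module (`ToyQuantLinear`, not part of `pytorch_quantization`) that calls the two fake-quantization ops directly, roughly as a quantized linear layer does once `use_fb_fake_quant` is enabled. The fixed input scale of 0.02 and the layer sizes are arbitrary assumptions; the only point is that both ops appear as `aten::fake_quantize_per_*_affine` nodes in the TorchScript graph.\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class ToyQuantLinear(torch.nn.Module):\n",
    "    # Hypothetical stand-in: per-tensor input quantization, per-channel weight quantization\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear = torch.nn.Linear(8, 4)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_q = torch.fake_quantize_per_tensor_affine(x, 0.02, 0, -127, 127)\n",
    "        w = self.linear.weight\n",
    "        w_scale = w.abs().amax(dim=1) / 127  # one scale per output channel\n",
    "        zero_points = torch.zeros(4, dtype=torch.int32)  # one zero point per output channel\n",
    "        w_q = torch.fake_quantize_per_channel_affine(w, w_scale, zero_points, 0, -127, 127)\n",
    "        return F.linear(x_q, w_q, self.linear.bias)\n",
    "\n",
    "with torch.no_grad():\n",
    "    traced = torch.jit.trace(ToyQuantLinear().eval(), torch.randn(2, 8))\n",
    "\n",
    "graph_text = str(traced.graph)\n",
    "for op in ('aten::fake_quantize_per_tensor_affine', 'aten::fake_quantize_per_channel_affine'):\n",
    "    print(op, op in graph_text)\n",
    "```"
   ]
  },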
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3d34f526",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E1109 04:02:37.101168 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.102248 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.107194 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.107625 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.115269 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.115740 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.117969 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.118358 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.126382 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.126834 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.128674 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.129518 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.135453 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.135936 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.137858 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.138366 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.145539 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.146053 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.147871 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.148353 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.154252 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.154685 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.156558 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.157159 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.163197 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.163676 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.165549 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.165991 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.173305 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.173926 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.176034 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.176697 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.182843 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.183426 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.185377 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.185962 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.191966 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.192424 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.194325 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.194817 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.201988 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.202665 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.204763 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.205461 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.211393 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.211987 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.213899 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.214450 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.220892 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.221533 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.223519 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.224037 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.233809 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.234434 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.238212 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.239042 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.241022 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.241654 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.247820 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.248445 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.250366 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.250959 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.257248 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.257854 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.259968 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.260660 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "W1109 04:02:37.268160 139704147265344 tensor_quantizer.py:280] Use Pytorch's native experimental fake quantization.\n",
      "/opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:285: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  inputs, amax.item() / bound, 0,\n",
      "/opt/conda/lib/python3.8/site-packages/pytorch_quantization/nn/modules/tensor_quantizer.py:291: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E1109 04:02:37.329273 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.330212 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.332529 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.333365 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.339547 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.340248 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.342257 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.342890 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.350619 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.351372 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.353470 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.354121 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.360090 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.360806 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.362803 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.363274 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.370369 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.371057 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.373071 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.373766 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.379890 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.380538 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.382532 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.383128 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.389077 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.389760 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.391815 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.392399 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.399809 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.400472 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.402399 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.402939 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.408818 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.409424 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.411513 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.412097 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.418537 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.419128 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.421343 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.421946 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.429382 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.430156 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.432259 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.433079 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.439297 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.440027 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.442149 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.442826 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.449377 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.449968 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.452122 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.452754 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.462532 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.463295 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.466963 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.467725 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.469692 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.470336 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.476204 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.476738 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.478809 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.479375 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.485666 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.486219 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.488416 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n",
      "E1109 04:02:37.488986 139704147265344 tensor_quantizer.py:120] Fake quantize mode doesn't use scale explicitly!\n"
     ]
    }
   ],
   "source": [
    "quant_nn.TensorQuantizer.use_fb_fake_quant = True\n",
    "with torch.no_grad():\n",
    "    data = iter(testing_dataloader)\n",
    "    images, _ = data.next()\n",
    "    jit_model = torch.jit.trace(qat_model, images.to(\"cuda\"))\n",
    "    torch.jit.save(jit_model, \"trained_vgg16_qat.jit.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7341418a",
   "metadata": {},
   "source": [
    "<a id=\"8\"></a>\n",
    "##  8. Inference using Torch-TensorRT\n",
    "In this phase, we run the exported torchscript graph of VGG QAT using Torch-TensorRT. Torch-TensorRT is a Pytorch-TensorRT compiler which converts Torchscript graphs into TensorRT. TensorRT 8.0 supports inference of quantization aware trained models and introduces new APIs; `QuantizeLayer` and `DequantizeLayer`. We can observe the entire VGG QAT graph quantization nodes from the debug log of Torch-TensorRT. To enable debug logging, you can set `torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)`. For example, `QuantConv2d` layer from `pytorch_quantization` toolkit is represented as follows in Torchscript\n",
    "```\n",
    "%quant_input : Tensor = aten::fake_quantize_per_tensor_affine(%x, %636, %637, %638, %639)\n",
    "%quant_weight : Tensor = aten::fake_quantize_per_channel_affine(%394, %640, %641, %637, %638, %639)\n",
    "%input.2 : Tensor = aten::_convolution(%quant_input, %quant_weight, %395, %687, %688, %689, %643, %690, %642, %643, %643, %644, %644)\n",
    "```\n",
    "`aten::fake_quantize_per_*_affine` is converted into `QuantizeLayer` + `DequantizeLayer` in Torch-TensorRT internally. Please refer to <a href=\"https://github.com/NVIDIA/Torch-TensorRT/blob/master/core/conversion/converters/impl/quantization.cpp\">quantization op converters</a> in Torch-TensorRT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "aa7495e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.2. Assuming it is Float32. If not, specify input type explicity\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter\n",
      "WARNING: [Torch-TensorRT TorchScript Conversion Context] - Detected invalid timing cache, setup a local cache instead\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG QAT accuracy using TensorRT: 82.97%\n"
     ]
    }
   ],
   "source": [
    "qat_model = torch.jit.load(\"trained_vgg16_qat.jit.pt\").eval()\n",
    "\n",
    "compile_spec = {\"inputs\": [torch_tensorrt.Input([16, 3, 32, 32])],\n",
    "                \"enabled_precisions\": torch.int8,\n",
    "                }\n",
    "trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)\n",
    "\n",
    "test_loss, test_acc = test(trt_mod, testing_dataloader, crit, 0)\n",
    "print(\"VGG QAT accuracy using TensorRT: {:.2f}%\".format(100 * test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9df5a90e",
   "metadata": {},
   "source": [
    "### Performance benchmarking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9eb2cd2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "\n",
    "import torch.backends.cudnn as cudnn\n",
    "cudnn.benchmark = True\n",
    "\n",
    "# Helper function to benchmark the model on random CUDA inputs\n",
    "def benchmark(model, input_shape=(1024, 3, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):\n",
    "    input_data = torch.randn(input_shape).to(\"cuda\")\n",
    "    if dtype == 'fp16':\n",
    "        input_data = input_data.half()\n",
    "\n",
    "    # Warm-up runs absorb one-time costs (cuDNN autotuning, memory allocation) before timing\n",
    "    print(\"Warm up ...\")\n",
    "    with torch.no_grad():\n",
    "        for _ in range(nwarmup):\n",
    "            model(input_data)\n",
    "    torch.cuda.synchronize()\n",
    "\n",
    "    print(\"Start timing ...\")\n",
    "    timings = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(1, nruns + 1):\n",
    "            start_time = time.perf_counter()\n",
    "            output = model(input_data)\n",
    "            # CUDA kernels run asynchronously; wait for them before stopping the clock\n",
    "            torch.cuda.synchronize()\n",
    "            end_time = time.perf_counter()\n",
    "            timings.append(end_time - start_time)\n",
    "            if i % 100 == 0:\n",
    "                print('Iteration %d/%d, avg batch time %.2f ms' % (i, nruns, np.mean(timings) * 1000))\n",
    "\n",
    "    print(\"Input shape:\", input_data.size())\n",
    "    print(\"Output shape:\", output.shape)\n",
    "    print('Average batch time: %.2f ms' % (np.mean(timings) * 1000))\n"
   ]
  },
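  {
   "cell_type": "markdown",
   "id": "d9b1f704",
   "metadata": {},
   "source": [
    "`benchmark` reports only the mean batch time, which can hide occasional slow iterations. As a minimal sketch, assuming `benchmark` were changed to end with `return timings` (its per-batch times in seconds), the hypothetical helper `summarize_timings` below adds the median, a nearest-rank 99th-percentile latency, and throughput in images per second. The sample timings are made up and only exercise the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c3a812",
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "\n",
    "# Hypothetical helper: summarize per-batch wall-clock times given in seconds\n",
    "def summarize_timings(timings, batch_size):\n",
    "    ms = sorted(t * 1000.0 for t in timings)\n",
    "    # Nearest-rank p99: the ceil(0.99 * n)-th smallest sample (integer ceiling)\n",
    "    rank = max(1, -(-99 * len(ms) // 100))\n",
    "    mean_ms = statistics.mean(ms)\n",
    "    return {\n",
    "        'mean_ms': mean_ms,\n",
    "        'median_ms': statistics.median(ms),\n",
    "        'p99_ms': ms[rank - 1],\n",
    "        'images_per_s': batch_size / (mean_ms / 1000.0),\n",
    "    }\n",
    "\n",
    "sample = [0.0018, 0.0019, 0.0017, 0.0018, 0.0030]  # made-up timings (seconds)\n",
    "print(summarize_timings(sample, batch_size=16))"
   ]
  },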
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5c2514ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, avg batch time 4.83 ms\n",
      "Iteration 200/1000, avg batch time 4.83 ms\n",
      "Iteration 300/1000, avg batch time 4.83 ms\n",
      "Iteration 400/1000, avg batch time 4.83 ms\n",
      "Iteration 500/1000, avg batch time 4.83 ms\n",
      "Iteration 600/1000, avg batch time 4.83 ms\n",
      "Iteration 700/1000, avg batch time 4.83 ms\n",
      "Iteration 800/1000, avg batch time 4.83 ms\n",
      "Iteration 900/1000, avg batch time 4.83 ms\n",
      "Iteration 1000/1000, avg batch time 4.83 ms\n",
      "Input shape: torch.Size([16, 3, 32, 32])\n",
      "Output shape: torch.Size([16, 10])\n",
      "Average batch time: 4.83 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(jit_model, input_shape=(16, 3, 32, 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c5378ed6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warm up ...\n",
      "Start timing ...\n",
      "Iteration 100/1000, avg batch time 1.87 ms\n",
      "Iteration 200/1000, avg batch time 1.84 ms\n",
      "Iteration 300/1000, avg batch time 1.85 ms\n",
      "Iteration 400/1000, avg batch time 1.83 ms\n",
      "Iteration 500/1000, avg batch time 1.82 ms\n",
      "Iteration 600/1000, avg batch time 1.81 ms\n",
      "Iteration 700/1000, avg batch time 1.81 ms\n",
      "Iteration 800/1000, avg batch time 1.80 ms\n",
      "Iteration 900/1000, avg batch time 1.80 ms\n",
      "Iteration 1000/1000, avg batch time 1.79 ms\n",
      "Input shape: torch.Size([16, 3, 32, 32])\n",
      "Output shape: torch.Size([16, 10])\n",
      "Average batch time: 1.79 ms\n"
     ]
    }
   ],
   "source": [
    "benchmark(trt_mod, input_shape=(16, 3, 32, 32))"
   ]
  },
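  {
   "cell_type": "markdown",
   "id": "f0a7d6b9",
   "metadata": {},
   "source": [
    "The averages printed above can be turned into a speedup and a throughput figure: for a batch of 16 images, `jit_model` took 4.83 ms and the INT8 TensorRT module `trt_mod` took 1.79 ms. The cell below does that arithmetic. The two latencies are copied by hand from this run, so substitute your own measurements, which depend on the GPU and software versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8e4c27",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 16\n",
    "jit_ms = 4.83   # average batch time of jit_model, copied from the output above\n",
    "trt_ms = 1.79   # average batch time of trt_mod, copied from the output above\n",
    "\n",
    "speedup = jit_ms / trt_ms\n",
    "jit_ips = batch_size / (jit_ms / 1000.0)   # images per second\n",
    "trt_ips = batch_size / (trt_ms / 1000.0)\n",
    "\n",
    "print(f'Speedup: {speedup:.2f}x')\n",
    "print(f'Throughput: {jit_ips:.0f} -> {trt_ips:.0f} images/s')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c6d2e90",
   "metadata": {},
   "source": [
    "For this run that is a speedup of about 2.7x, or roughly 3,300 versus 8,900 images per second."
   ]
  },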
  {
   "cell_type": "markdown",
   "id": "d6a5ec1c",
   "metadata": {},
   "source": [
    "<a id=\"9\"></a>\n",
    "##  9. References\n",
    "* <a href=\"https://arxiv.org/pdf/1409.1556.pdf\">Very Deep Convolutional Networks for Large-Scale Image Recognition</a>\n",
    "* <a href=\"https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/\">Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT</a>\n",
    "* <a href=\"https://github.com/NVIDIA/Torch-TensorRT/tree/master/examples/int8/training/vgg16#quantization-aware-fine-tuning-for-trying-out-qat-workflows\">QAT workflow for VGG16</a>\n",
    "* <a href=\"https://github.com/NVIDIA/Torch-TensorRT/tree/master/examples/int8/qat\">Deploying the VGG QAT model in C++ using Torch-TensorRT</a>\n",
    "* <a href=\"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization\">PyTorch Quantization Toolkit from NVIDIA</a>\n",
    "* <a href=\"https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/userguide.html\">PyTorch Quantization Toolkit user guide</a>\n",
    "* <a href=\"https://arxiv.org/pdf/2004.09602.pdf\">Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation</a>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
